# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Script to generate SLURM job scripts from Jinja2 templates.
"""

import argparse
import logging
import os
import pathlib
import subprocess
import tempfile
from datetime import datetime

from jinja2 import Template


def print_welcome_message(job_ids: list[str], log_dir_name: str):
    """Print a clean welcome message with job information."""

    _ = f"{', '.join(job_ids)}"
    print(
        f"""
🚀 Welcome! We hope you enjoy your time on our GB200 NVL72.

Your logs for this submitted job will be available in logs/{log_dir_name}
You can access them by running:

    cd logs/{log_dir_name}

You can view all of the prefill/decode worker logs by running:

    tail -f *_decode_*.err *_prefill_*.err

To kick off the benchmark we suggest opening up a new terminal, SSH-ing
into the login node, and running the srun command that is found at the
bottom of the log.out. You can find it by running:

    cat log.out

Enjoy :)
- NVIDIA
"""
    )


def setup_logging(level: int = logging.INFO) -> None:
    logging.basicConfig(
        level=level,
        format="%(asctime)s| %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


def generate_job_script(template_path, output_path, **kwargs):
    """Generate a job script from template with given parameters."""
    with open(template_path, "r") as f:
        template = Template(f.read())

    rendered_script = template.render(**kwargs)
    with open(output_path, "w") as f:
        f.write(rendered_script)

    return output_path, rendered_script


def submit_job(job_script_path, extra_slurm_args=[]):
    """
    Submit the job script to SLURM and extract the job ID from the output.

    Returns:
        The job ID of the submitted job.
    """
    try:
        command = (
            ["sbatch"]
            + ["--" + x for x in extra_slurm_args]
            + [
                job_script_path,
            ]
        )
        result = subprocess.run(command, capture_output=True, text=True, check=True)
        output_lines = result.stdout.strip().split("\n")

        # sbatch typically outputs: "Submitted batch job JOBID"
        job_id = output_lines[-1].split()[-1]
        logging.info(f"Job submitted successfully with ID: {job_id}")
        return job_id
    except subprocess.CalledProcessError as e:
        logging.error(f"Error submitting job: {e}")
        logging.error(f"stderr: {e.stderr}")
        raise
    except (IndexError, ValueError):
        logging.error(f"Error parsing job ID from sbatch output: {result.stdout}")
        raise


def _get_available_gpu_types() -> list[str]:
    """Discover available GPU types by scanning scripts directory structure.

    Looks for scripts in: scripts/{gpu_type}/{agg,disagg}/*.sh
    """
    script_dir = pathlib.Path(__file__).parent / "scripts"
    gpu_types = set()

    # Scan for GPU type directories (directories that contain agg/ or disagg/)
    for gpu_dir in script_dir.iterdir():
        if not gpu_dir.is_dir():
            continue

        # Check if this directory has agg/ or disagg/ subdirectories
        has_agg = (gpu_dir / "agg").is_dir()
        has_disagg = (gpu_dir / "disagg").is_dir()

        if has_agg or has_disagg:
            gpu_types.add(gpu_dir.name)

    return sorted(list(gpu_types))


def _parse_command_line_args(args: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Generate and submit SLURM job scripts"
    )

    # Get available GPU types dynamically
    available_gpu_types = _get_available_gpu_types()

    # Template parameters
    parser.add_argument("--job-name", default="dynamo_setup", help="SLURM job name")
    parser.add_argument("--account", required=True, help="SLURM account")
    parser.add_argument("--model-dir", required=True, help="Model directory path")
    parser.add_argument("--config-dir", required=True, help="Config directory path")
    parser.add_argument("--container-image", required=True, help="Container image")
    parser.add_argument(
        "--time-limit", default="04:00:00", help="Time limit (HH:MM:SS)"
    )
    parser.add_argument(
        "--prefill-nodes", type=int, default=None, help="Number of prefill nodes"
    )
    parser.add_argument(
        "--decode-nodes", type=int, default=None, help="Number of decode nodes"
    )
    parser.add_argument(
        "--prefill-workers", type=int, default=None, help="Number of prefill workers"
    )
    parser.add_argument(
        "--decode-workers", type=int, default=None, help="Number of decode workers"
    )
    parser.add_argument(
        "--agg-nodes", type=int, default=None, help="Number of aggregated worker nodes"
    )
    parser.add_argument(
        "--agg-workers", type=int, default=None, help="Number of aggregated workers"
    )
    parser.add_argument(
        "--gpus-per-node", type=int, default=8, help="Number of GPUs per node"
    )
    parser.add_argument(
        "--network-interface", default="eth3", help="Network interface to use"
    )
    parser.add_argument(
        "--gpu-type",
        choices=available_gpu_types,
        default=available_gpu_types[0] if available_gpu_types else None,
        help=f"GPU type to use. Available types: {', '.join(available_gpu_types)}",
    )
    parser.add_argument(
        "--script-variant",
        type=str,
        default="default",
        help="Script variant to use (e.g., 'default', 'optim', 'decode-optim'). Defaults to 'default.sh'",
    )

    parser.add_argument(
        "--partition",
        default="batch",
        help="SLURM partition to use",
    )
    parser.add_argument(
        "--enable-multiple-frontends",
        action="store_true",
        help="Enable multiple frontend architecture with nginx load balancer",
    )
    parser.add_argument(
        "--num-additional-frontends",
        type=int,
        default=0,
        help="Number of additional frontend nodes (beyond the first frontend on node 1)",
    )

    parser.add_argument(
        "--use-init-location",
        action="store_true",
        help="Whether we use '--init-expert-locations' json files",
    )

    parser.add_argument(
        "--profiler",
        type=str,
        help="Profiler configurations. Example: "
        + '"type=vllm; isl=8192; osl=1024; concurrencies=16x2048x4096x8192; req-rate=inf"',
    )

    parser.add_argument(
        "--extra-slurm-args",
        action="append",
        default=[],
        help="Extra slurm arguments, remove the '--' prefix. Example: --extra-slurm-args dependency=afterok:<x>",
    )

    parser.add_argument(
        "--retries",
        type=int,
        default=0,
        help="Tries to launch the job multiple times to catch transient errors",
    )

    parser.add_argument(
        "--disable-config-dump",
        action="store_false",
        dest="enable_config_dump",
        default=True,
        help="Disable dumping config to file on each node (default: config dump is enabled)",
    )

    parser.add_argument(
        "--run-in-ci",
        action="store_true",
        help="Run in CI mode - use binaries from /configs/ for nats/etcd and install dynamo wheel",
    )

    return parser.parse_args(args)


def _validate_args(args: argparse.Namespace) -> None:
    """Validate arguments and ensure aggregated and disaggregated args are mutually exclusive."""
    has_disagg_args = any(
        [
            args.prefill_nodes is not None,
            args.decode_nodes is not None,
            args.prefill_workers is not None,
            args.decode_workers is not None,
        ]
    )
    has_agg_args = any(
        [
            args.agg_nodes is not None,
            args.agg_workers is not None,
        ]
    )

    if has_disagg_args and has_agg_args:
        raise ValueError(
            "Cannot specify both aggregated (--agg-nodes, --agg-workers) and "
            "disaggregated (--prefill-nodes, --decode-nodes, --prefill-workers, --decode-workers) arguments"
        )

    if has_disagg_args:
        # Validate disaggregated args
        if args.prefill_nodes is None or args.decode_nodes is None:
            raise ValueError(
                "Disaggregated mode requires both --prefill-nodes and --decode-nodes"
            )
        if args.prefill_workers is None or args.decode_workers is None:
            raise ValueError(
                "Disaggregated mode requires both --prefill-workers and --decode-workers"
            )
        if args.prefill_nodes % args.prefill_workers != 0:
            raise ValueError(
                f"Prefill nodes ({args.prefill_nodes}) must be divisible by prefill workers ({args.prefill_workers})"
            )
        if args.decode_nodes % args.decode_workers != 0:
            raise ValueError(
                f"Decode nodes ({args.decode_nodes}) must be divisible by decode workers ({args.decode_workers})"
            )
        # Validate GPU script exists for disaggregated mode
        script_dir = pathlib.Path(__file__).parent / "scripts"
        disagg_dir = script_dir / args.gpu_type / "disagg"
        # Use script variant (defaults to "default")
        script_name = f"{args.script_variant}.sh"
        gpu_script = disagg_dir / script_name
        if not gpu_script.exists():
            raise ValueError(
                f"Disaggregated GPU script not found: {gpu_script}. Available GPU types: {', '.join(_get_available_gpu_types())}"
            )

    if has_agg_args:
        # Validate aggregated args
        if args.agg_nodes is None or args.agg_workers is None:
            raise ValueError(
                "Aggregated mode requires both --agg-nodes and --agg-workers"
            )
        if args.agg_nodes % args.agg_workers != 0:
            raise ValueError(
                f"Aggregated nodes ({args.agg_nodes}) must be divisible by aggregated workers ({args.agg_workers})"
            )
        # Validate aggregated GPU script exists
        script_dir = pathlib.Path(__file__).parent / "scripts"
        # Remove any -prefill or -decode suffix if present
        base_gpu_type = args.gpu_type.replace("-prefill", "").replace("-decode", "")
        agg_dir = script_dir / base_gpu_type / "agg"
        # Use script variant (defaults to "default")
        script_name = f"{args.script_variant}.sh"
        agg_gpu_script = agg_dir / script_name
        if not agg_gpu_script.exists():
            raise ValueError(
                f"Aggregated GPU script not found: {agg_gpu_script}. Available GPU types: {', '.join(_get_available_gpu_types())}"
            )

    if not has_disagg_args and not has_agg_args:
        raise ValueError(
            "Must specify either aggregated (--agg-nodes, --agg-workers) or "
            "disaggregated (--prefill-nodes, --decode-nodes, --prefill-workers, --decode-workers) arguments"
        )


def main(input_args: list[str] | None = None):
    setup_logging()
    args = _parse_command_line_args(input_args)

    # Validate arguments
    _validate_args(args)

    # Determine mode and set defaults
    is_aggregated = args.agg_nodes is not None

    if is_aggregated:
        agg_nodes = args.agg_nodes
        agg_workers = args.agg_workers
        prefill_nodes = 0
        decode_nodes = 0
        prefill_workers = 0
        decode_workers = 0
        total_nodes = agg_nodes
    else:
        prefill_nodes = args.prefill_nodes
        decode_nodes = args.decode_nodes
        prefill_workers = args.prefill_workers
        decode_workers = args.decode_workers
        agg_nodes = 0
        agg_workers = 0
        total_nodes = prefill_nodes + decode_nodes

    # Validation for multiple frontends
    if args.enable_multiple_frontends:
        if args.num_additional_frontends < 0:
            raise ValueError("Number of additional frontends cannot be negative")

    # parse profiler configs
    profiler_config = {}
    if args.profiler:
        for key_val_pair in args.profiler.split("; "):
            key, val = key_val_pair.split("=")
            profiler_config[key] = val

    # validate profiler configs
    if profiler_config == {} or profiler_config["type"] == "manual":
        parsable_config = ""
        profiler_config["type"] = "manual"
    elif profiler_config["type"] in ["sglang", "vllm", "gap"]:
        parsable_config = ""
        need_keys = ["isl", "osl", "concurrencies"]
        assert all([key in profiler_config for key in need_keys])
        assert profiler_config["isl"].isnumeric()
        parsable_config = f"{parsable_config} {profiler_config['isl']}"
        assert profiler_config["osl"].isnumeric()
        parsable_config = f"{parsable_config} {profiler_config['osl']}"
        assert all([x.isnumeric() for x in profiler_config["concurrencies"].split("x")])
        parsable_config = f"{parsable_config} {profiler_config['concurrencies']}"

        if profiler_config["type"] in ["sglang", "vllm"]:
            assert "req-rate" in profiler_config
            assert (
                profiler_config["req-rate"] == "inf"
                or profiler_config["req-rate"].isnumeric()
            )
            parsable_config = f"{parsable_config} {profiler_config['req-rate']}"
    else:
        assert False, profiler_config["type"]

    # Generate timestamp for log directory naming
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Select template based on mode
    if is_aggregated:
        template_path = "job_script_template_agg.j2"
    else:
        template_path = "job_script_template_disagg.j2"

    template_vars = {
        "job_name": args.job_name,
        "total_nodes": total_nodes,
        "account": args.account,
        "time_limit": args.time_limit,
        "prefill_nodes": prefill_nodes,
        "decode_nodes": decode_nodes,
        "prefill_workers": prefill_workers,
        "decode_workers": decode_workers,
        "agg_nodes": agg_nodes,
        "agg_workers": agg_workers,
        "is_aggregated": is_aggregated,
        "model_dir": args.model_dir,
        "config_dir": args.config_dir,
        "container_image": args.container_image,
        "gpus_per_node": args.gpus_per_node,
        "network_interface": args.network_interface,
        "gpu_type": args.gpu_type,
        "script_variant": args.script_variant,
        "partition": args.partition,
        "enable_multiple_frontends": args.enable_multiple_frontends,
        "num_additional_frontends": args.num_additional_frontends,
        "use_init_location": args.use_init_location,
        "do_profile": profiler_config["type"] != "manual",
        "profiler_type": profiler_config["type"],
        "profiler_arg": parsable_config,
        "timestamp": timestamp,
        "enable_config_dump": args.enable_config_dump,
        "run_in_ci": args.run_in_ci,
    }

    # Create temporary file for sbatch script
    temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False)
    temp_path = temp_file.name
    temp_file.close()

    try:
        _, rendered_script = generate_job_script(
            template_path, temp_path, **template_vars
        )

        submitted_job_ids = []
        job_id = submit_job(temp_path, args.extra_slurm_args)
        submitted_job_ids.append(job_id)

        # Create log directory with new naming format IMMEDIATELY after submission
        # SLURM will write log.out/log.err to this directory when job starts
        if is_aggregated:
            log_dir_name = f"{job_id}_{agg_workers}A_{timestamp}"
        else:
            log_dir_name = f"{job_id}_{prefill_workers}P_{decode_workers}D_{timestamp}"
        log_dir_path = os.path.join("logs", log_dir_name)
        os.makedirs(log_dir_path, exist_ok=True)

        # Save rendered sbatch script
        sbatch_script_path = os.path.join(log_dir_path, "sbatch_script.sh")
        with open(sbatch_script_path, "w") as f:
            f.write(rendered_script)
        logging.info(f"Saved rendered sbatch script to {sbatch_script_path}")

        # retries logic
        if args.retries > 0:
            extra_slurm_args_without_dependencies = [
                x for x in args.extra_slurm_args if "dependency" not in x
            ]
            for _ in range(args.retries):
                dependencies = ",".join(
                    [f"afternotok:{job}" for job in submitted_job_ids]
                )
                slurm_args = extra_slurm_args_without_dependencies + [
                    f"dependency={dependencies}"
                ]
                job_id = submit_job(temp_path, slurm_args)
                submitted_job_ids.append(job_id)

                # Save script for retry job as well
                if is_aggregated:
                    retry_log_dir_name = f"{job_id}_{agg_workers}A_{timestamp}"
                else:
                    retry_log_dir_name = (
                        f"{job_id}_{prefill_workers}P_{decode_workers}D_{timestamp}"
                    )
                retry_log_dir_path = os.path.join("logs", retry_log_dir_name)
                os.makedirs(retry_log_dir_path, exist_ok=True)
                retry_sbatch_script_path = os.path.join(
                    retry_log_dir_path, "sbatch_script.sh"
                )
                with open(retry_sbatch_script_path, "w") as f:
                    f.write(rendered_script)
                logging.info(
                    f"Saved rendered sbatch script to {retry_sbatch_script_path}"
                )

        print_welcome_message(submitted_job_ids, log_dir_name)
    finally:
        # Clean up temporary file
        try:
            os.unlink(temp_path)
        except OSError:
            pass


if __name__ == "__main__":
    main()
